#!/usr/bin/env python3
'''
procdockerstatsd
Daemon which periodically gathers process, docker and FIPS statistics and pushes the data to STATE_DB
'''

import os
import psutil
import re
import subprocess
import sys
import time
from datetime import datetime, timedelta

from sonic_py_common import daemon_base
from swsscommon import swsscommon
from sonic_py_common.general import getstatusoutput_noshell_pipe, getstatusoutput_noshell

# Version string (not referenced elsewhere in this file)
VERSION = '1.0'

# Log identifier handed to the ProcDockerStats constructor in main()
SYSLOG_IDENTIFIER = "procdockerstatsd"

# Host address given to SonicV2Connector when connecting to STATE_DB
REDIS_HOSTIP = "127.0.0.1"


class ProcDockerStats(daemon_base.DaemonBase):
    all_process_obj = {}

    def __init__(self, log_identifier):
        """
        Initialise the DaemonBase parent and open the STATE_DB connection.

        Args:
            log_identifier: identifier passed on to the DaemonBase constructor.
        """
        super(ProcDockerStats, self).__init__(log_identifier)
        # Per-instance cache of process objects keyed by PID. Binding it here
        # keeps instances from sharing (and mutating) the class-level dict.
        self.all_process_obj = {}
        self.state_db = swsscommon.SonicV2Connector(host=REDIS_HOSTIP)
        self.state_db.connect("STATE_DB")

    def run_command(self, cmd):
        proc = subprocess.Popen(cmd, universal_newlines=True, stdout=subprocess.PIPE)
        (stdout, stderr) = proc.communicate()
        if proc.returncode != 0:
            self.log_error("Error running command '{}'".format(cmd))
            return None
        else:
            return stdout

    def format_docker_cmd_output(self, cmdout):
        lines = cmdout.splitlines()
        keys = re.split("   +", lines[0])
        docker_data = dict()
        docker_data_list = []
        for line in lines[1:]:
            values = re.split("   +", line)
            docker_data = {key: value for key, value in zip(keys, values)}
            docker_data_list.append(docker_data)
        formatted_dict = self.create_docker_dict(docker_data_list)
        return formatted_dict

    def format_process_cmd_output(self, cmdout):
        lines = cmdout.splitlines()
        keys = re.split(" +", lines[0])
        key_list = [key for key in keys if key]
        process_data = dict()
        process_data_list = []
        for line in lines[1:]:
            values = re.split(" +", line)
            # To remove extra space before UID
            val_list = [val for val in values if val]
            # Merging extra columns created due to space in cmd ouput
            val_list[8:] = [' '.join(val_list[8:])]
            process_data = {key: value for key, value in zip(key_list, val_list)}
            process_data_list.append(process_data)
        return process_data_list

    def convert_to_bytes(self, value):
        UNITS_B = 'B'
        UNITS_KB = 'KB'
        UNITS_MB = 'MB'
        UNITS_MiB = 'MiB'
        UNITS_GiB = 'GiB'

        res = re.match(r'(\d+\.?\d*)([a-zA-Z]+)', value)
        value = float(res.groups()[0])
        units = res.groups()[1]
        if units.lower() == UNITS_KB.lower():
            value *= 1000
        elif units.lower() == UNITS_MB.lower():
            value *= (1000 * 1000)
        elif units.lower() == UNITS_MiB.lower():
            value *= (1024 * 1024)
        elif units.lower() == UNITS_GiB.lower():
            value *= (1024 * 1024 * 1024)

        return int(round(value))

    def create_docker_dict(self, dict_list):
        dockerdict = {}
        for row in dict_list[0:]:
            cid = row.get('CONTAINER ID')
            if cid:
                key = 'DOCKER_STATS|{}'.format(cid)
                dockerdict[key] = {}
                dockerdict[key]['NAME'] = row.get('NAME')

                cpu = row.get('CPU %').split("%")
                dockerdict[key]['CPU%'] = str(cpu[0])

                memuse = row.get('MEM USAGE / LIMIT').split(" / ")
                # converting MiB and GiB to bytes
                dockerdict[key]['MEM_BYTES'] = str(self.convert_to_bytes(memuse[0]))
                dockerdict[key]['MEM_LIMIT_BYTES'] = str(self.convert_to_bytes(memuse[1]))

                mem = row.get('MEM %').split("%")
                dockerdict[key]['MEM%'] = str(mem[0])

                netio = row.get('NET I/O').split(" / ")
                dockerdict[key]['NET_IN_BYTES'] = str(self.convert_to_bytes(netio[0]))
                dockerdict[key]['NET_OUT_BYTES'] = str(self.convert_to_bytes(netio[1]))

                blockio = row.get('BLOCK I/O').split(" / ")
                dockerdict[key]['BLOCK_IN_BYTES'] = str(self.convert_to_bytes(blockio[0]))
                dockerdict[key]['BLOCK_OUT_BYTES'] = str(self.convert_to_bytes(blockio[1]))

                dockerdict[key]['PIDS'] = row.get('PIDS')
        return dockerdict

    def update_dockerstats_command(self):
        cmd = ["docker", "stats", "--no-stream", "-a"]
        data = self.run_command(cmd)
        if not data:
            self.log_error("'{}' returned null output".format(cmd))
            return False
        dockerdata = self.format_docker_cmd_output(data)
        if not dockerdata:
            self.log_error("formatting for docker output failed")
            return False
        # wipe out all data from state_db before updating
        self.state_db.delete_all_by_pattern('STATE_DB', 'DOCKER_STATS|*')
        for k1,v1 in dockerdata.items():
            self.batch_update_state_db(k1, v1)
        return True

    def update_processstats_command(self):
        """
        Publish statistics for the busiest processes to STATE_DB.

        Processes are enumerated with psutil, ordered by CPU usage (highest
        first) and capped at 1024. All existing 'PROCESS_STATS|*' keys are
        deleted, then one 'PROCESS_STATS|<pid>' hash with the fields UID,
        PPID, CPU, MEM, STIME, TT, TIME and CMD is written per process.
        Processes that disappear or cannot be inspected while being sampled
        are skipped.
        """
        unavailable = (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess)

        def cpu_usage(proc):
            # A process may exit between enumeration and sorting; rank it as
            # idle rather than letting the exception abort the whole update.
            try:
                return proc.cpu_percent()
            except unavailable:
                return 0.0

        processes_all = list(psutil.process_iter(['pid', 'ppid', 'memory_percent', 'cpu_percent', 'create_time', 'cmdline']))
        top_processes = sorted(processes_all, key=cpu_usage, reverse=True)[:1024]

        entries = []
        pid_set = set()
        for process_obj in top_processes:
            try:
                pid = process_obj.pid
                pid_set.add(pid)
                # store process object and reuse for CPU utilization
                process = self.all_process_obj.setdefault(pid, process_obj)
                uid = process.uids()[0]
                ppid = process.ppid()
                mem = process.memory_percent()
                # NOTE(review): psutil documents cpu_percent() as relative to
                # the previous call on the same object, and it is also called
                # for sorting above - confirm the value published here covers
                # the intended interval.
                cpu = process.cpu_percent()
                stime = process.create_time()
                stime_formatted = datetime.utcfromtimestamp(stime).strftime("%b%d")
                tty = process.terminal()
                ttime = process.cpu_times()
                time_formatted = str(timedelta(seconds=int(ttime.user + ttime.system)))
                cmd = ' '.join(process.cmdline())
            except unavailable:
                continue

            # A new dict per process, so no entry shares state with another
            entries.append((pid, {
                'UID': str(uid),
                'PPID': str(ppid),
                'CPU': str(cpu),
                'MEM': str(round(mem, 1)),
                'STIME': str(stime_formatted),
                'TT': str(tty),
                'TIME': str(time_formatted),
                'CMD': cmd,
            }))

        # erase dead process
        stale_pids = [known_pid for known_pid in self.all_process_obj if known_pid not in pid_set]
        for stale_pid in stale_pids:
            del self.all_process_obj[stale_pid]

        # wipe out all data before updating with new values
        self.state_db.delete_all_by_pattern('STATE_DB', 'PROCESS_STATS|*')

        for pid, fields in entries:
            if pid:
                self.batch_update_state_db('PROCESS_STATS|{}'.format(pid), fields)

    def update_fipsstats_command(self):
        """
        Publish the FIPS state to STATE_DB under the key 'FIPS_STATS|state'.

        The hash carries a UTC 'timestamp', 'enforced' (whether the kernel
        command line contains 'sonic_fips=1' or 'fips=1') and 'enabled'
        (derived from the exit status of an openssl/grep pipeline).
        """
        # FIPS counts as enforced when the running kernel was booted with
        # one of the FIPS flags
        with open('/proc/cmdline') as cmdline_file:
            boot_args = cmdline_file.read().strip().split(' ')
        enforced = any(flag in boot_args for flag in ('sonic_fips=1', 'fips=1'))

        # Runtime status: look for 'symcryp' in the openssl engine listing.
        # The first returned item is iterated with any(), presumably one exit
        # status per pipeline stage - verify against the helper.
        exit_codes, _ = getstatusoutput_noshell_pipe(['sudo', 'openssl', 'engine', '-vv'], ['grep', '-i', 'symcryp'])
        enabled = not any(exit_codes)

        fips_state = {
            'timestamp': datetime.utcnow().isoformat(),
            'enforced': str(enforced),
            'enabled': str(enabled),
        }
        self.batch_update_state_db('FIPS_STATS|state', fips_state)

    def update_state_db(self, key1, key2, value2):
        """
        Write a single field to STATE_DB.

        Args:
            key1: STATE_DB key to update.
            key2: field name under that key.
            value2: value stored for the field.
        """
        self.state_db.set('STATE_DB', key1, key2, value2)

    def batch_update_state_db(self, key1, fvs):
        """
        Write several fields of one STATE_DB key in a single hmset call.

        Args:
            key1: STATE_DB key to update.
            fvs: dict of field names to values.
        """
        self.state_db.hmset('STATE_DB', key1, fvs)
  
    def run(self):
        """
        Main loop of the daemon.

        Exits with status 1 unless the process runs as root. Otherwise it
        refreshes the docker, process and FIPS data in STATE_DB every 120
        seconds and stores a 'lastupdate' field under each table's
        'LastUpdateTime' key.
        """
        self.log_info("Starting up ...")

        if os.getuid() != 0:
            self.log_error("Must be root to run this daemon")
            print("Must be root to run this daemon")
            sys.exit(1)

        # Data is refreshed every 2 minutes
        refresh_interval_secs = 120

        while True:
            self.update_dockerstats_command()
            # One timestamp, taken right after the docker refresh, is shared
            # by all three LastUpdateTime keys.
            last_update = str(datetime.now())
            self.update_state_db('DOCKER_STATS|LastUpdateTime', 'lastupdate', last_update)

            self.update_processstats_command()
            self.update_state_db('PROCESS_STATS|LastUpdateTime', 'lastupdate', last_update)

            self.update_fipsstats_command()
            self.update_state_db('FIPS_STATS|LastUpdateTime', 'lastupdate', last_update)

            time.sleep(refresh_interval_secs)

        # Not reached while the loop above has no exit condition
        self.log_info("Exiting ...")


def main():
    """Create the daemon, set its minimum log priority to INFO and run it."""
    daemon = ProcDockerStats(SYSLOG_IDENTIFIER)

    # Messages of INFO level and higher get logged
    daemon.set_min_log_priority_info()

    daemon.run()


# Start the daemon only when this file is executed as a script
if __name__ == '__main__':
    main()
